# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Updating MPS weights in multiprocess/distributed data collectors
================================================================

Overview of the Script
----------------------

This script demonstrates a weight update in TorchRL.
The script uses a custom `MPSWeightUpdaterBase` class to update the weights of a policy network across multiple workers.

Key Features
------------

- Multi-Worker Setup: The script creates two worker processes that collect data from a Gym environment
  ("Pendulum-v1") using a policy network.
- MPS (Metal Performance Shaders) Device: The policy network is placed on an MPS device.
- Custom Weight Updater: The `MPSWeightUpdaterBase` class is used to update the policy weights across workers. This
  class is necessary because MPS tensors cannot be sent over a pipe due to serialization/pickling issues in PyTorch.

Workaround for MPS Tensor Serialization Issue
---------------------------------------------

In PyTorch, MPS tensors cannot be serialized or pickled, which means they cannot be sent over a pipe or shared between
processes. To work around this issue, the MPSWeightUpdaterBase class sends the policy weights on the CPU device
instead of the MPS device. The local workers then copy the weights from the CPU device to the MPS device.

Script Flow
-----------

1. Initialize the environment, policy network, and collector.
2. Update the policy weights using the MPSWeightUpdaterBase.
3. Collect data from the environment using the policy network.
4. Zero out the policy weights after a few iterations.
5. Verify that the updated policy weights are being used by checking the actions generated by the policy network.

"""

import tensordict
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import MultiSyncDataCollector, WeightUpdaterBase

from torchrl.envs.libs.gym import GymEnv


class MPSWeightUpdaterBase(WeightUpdaterBase):
    """Weight updater that ships MPS-resident policy weights to workers via CPU.

    MPS tensors cannot be pickled and therefore cannot travel over a
    multiprocessing pipe. This updater copies the server-side weights to CPU
    before dispatching them; each local worker then performs the cpu->mps copy
    on its side.

    Args:
        policy_weights: TensorDict holding the server policy's parameters
            (resident on the mps device).
        num_workers: number of collector workers to broadcast updates to.
    """

    def __init__(self, policy_weights, num_workers):
        # Weights are on mps device, which cannot be shared across processes.
        # `.data` detaches the parameters from autograd tracking.
        self.policy_weights = policy_weights.data
        self.num_workers = num_workers

    def _sync_weights_with_worker(
        self, worker_id: int | torch.device, server_weights: TensorDictBase
    ) -> TensorDictBase:
        """Send CPU weights to one worker and wait for its acknowledgment."""
        # Send weights on cpu - the local workers will do the cpu->mps copy.
        self.collector.pipes[worker_id].send((server_weights, "update"))
        # Block until the worker confirms it applied the update.
        _, msg = self.collector.pipes[worker_id].recv()
        if msg != "updated":
            # Explicit check instead of `assert`: survives `python -O` and
            # carries a useful message if the protocol is violated.
            raise RuntimeError(
                f"Expected 'updated' acknowledgment from worker {worker_id}, "
                f"got {msg!r}."
            )
        return server_weights

    def _get_server_weights(self) -> TensorDictBase:
        """Return a CPU copy of the server weights so they can be pickled."""
        return self.policy_weights.cpu()

    def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
        """No-op mapping: weights are already on CPU at this point."""
        return server_weights

    def all_worker_ids(self) -> list[int] | list[torch.device]:
        # Workers are addressed by their integer index in the collector's
        # pipe list.
        return list(range(self.num_workers))


if __name__ == "__main__":
    device = "mps"

    def make_env():
        # Environments run on cpu; only the policy lives on mps.
        return GymEnv("Pendulum-v1", device="cpu")

    def make_policy(device=device):
        # Minimal linear actor mapping observations to actions.
        actor = nn.Linear(3, 1)
        module = TensorDictModule(
            actor, in_keys=["observation"], out_keys=["action"]
        )
        return module.to(device=device)

    policy = make_policy()
    # Server-side view of the policy parameters (resident on mps).
    policy_weights = tensordict.from_module(policy)

    collector = MultiSyncDataCollector(
        create_env_fn=[make_env, make_env],
        policy_factory=make_policy,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        reset_at_each_iter=False,
        device=device,
        storing_device="cpu",
        weight_updater=MPSWeightUpdaterBase(policy_weights, 2),
        # use_buffers=False,
        # cat_results="stack",
    )

    # Push the initial weights to every worker before collecting.
    collector.update_policy_weights_()
    try:
        for iteration, batch in enumerate(collector):
            if iteration == 2:
                print(batch)
                # Non-zero weights must yield some non-zero actions.
                assert (batch["action"] != 0).any()
                # zero the policy
                policy_weights.data.zero_()
                collector.update_policy_weights_()
            elif iteration == 3:
                # Workers now run the zeroed policy: all actions must be 0.
                assert (batch["action"] == 0).all(), batch["action"]
                break
    finally:
        collector.shutdown()